{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pytorch Implementation of InfoGAN using MNIST dataset\n",
    "# code by GunhoChoi\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils as utils\n",
    "import torch.nn.init as init\n",
    "from torch.autograd import Variable\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.utils as v_utils\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Set Hyperparameters\n",
    "\n",
    "epoch = 1000\n",
    "batch_size = 200\n",
    "learning_rate = 0.001\n",
    "num_gpus = 1\n",
    "discrete_latent_size = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded\n"
     ]
    }
   ],
   "source": [
    "# Download Data & Set Data Loader(input pipeline)\n",
    "\n",
    "mnist_train = dset.MNIST(\"./\", train=True, transform=transforms.ToTensor(),\\\n",
    "                         target_transform=None, download=True)\n",
    "train_loader = torch.utils.data.DataLoader(dataset=mnist_train,batch_size=batch_size,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define integer to onehot vector function\n",
    "# 3 -> [0 0 0 1 0 0 0 0 0 0]\n",
    "\n",
    "def int_to_onehot(z_label):\n",
    "    one_hot_array = np.zeros(shape=[len(z_label), discrete_latent_size])\n",
    "    one_hot_array[np.arange(len(z_label)), z_label] = 1\n",
    "    return one_hot_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define Generator \n",
    "# [z + one-hot vector] -> [28 x 28 x 1]\n",
    "# z is vector of 100 random numbers(0~1) \n",
    "# one-hot vector size is 10\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Generator,self).__init__()\n",
    "        self.layer1 = nn.Linear(110,7*7*256)                    # [batch,110]->[batch,7*7*256]\n",
    "        self.layer2 = nn.Sequential(OrderedDict([\n",
    "                ('conv1', nn.ConvTranspose2d(256,128,3,2,1,1)), # [batch,256,7,7]->[batch,128,14,14]\n",
    "                ('relu1', nn.LeakyReLU()),\n",
    "                ('bn1', nn.BatchNorm2d(128)),\n",
    "                ('conv2', nn.ConvTranspose2d(128,64,3,1,1)),    # [batch,128,14,14]->[batch,64,14,14]\n",
    "                ('relu2', nn.LeakyReLU()),\n",
    "                ('bn2', nn.BatchNorm2d(64))\n",
    "            ]))\n",
    "        self.layer3 = nn.Sequential(OrderedDict([\n",
    "                ('conv3',nn.ConvTranspose2d(64,16,3,1,1)),      # [batch,64,14,14]->[batch,16,14,14]\n",
    "                ('relu3',nn.LeakyReLU()),\n",
    "                ('bn3',nn.BatchNorm2d(16)),\n",
    "                ('conv4',nn.ConvTranspose2d(16,1,3,2,1,1)),     # [batch,16,14,14]->[batch,1,28,28]\n",
    "                ('relu4',nn.LeakyReLU())\n",
    "            ]))\n",
    "\n",
    "    def forward(self,z):\n",
    "        out = self.layer1(z)\n",
    "        out = out.view(batch_size//num_gpus,256,7,7)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define Discriminator \n",
    "# [28 x 28 x 1] -> output variable(real data or not), one-hot vector(label)  \n",
    "# output variable is one number(0~1)\n",
    "# one-hot vector size is 10\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Discriminator,self).__init__()\n",
    "        self.layer1 = nn.Sequential(OrderedDict([\n",
    "                ('conv1',nn.Conv2d(1,16,3,padding=1)),   # [batch,1,28,28]  -> [batch,16,28,28]\n",
    "                ('relu1',nn.LeakyReLU()),\n",
    "                ('bn1',nn.BatchNorm2d(16)),\n",
    "                ('conv2',nn.Conv2d(16,32,3,padding=1)),  # [batch,16,28,28] -> [batch,32,28,28]\n",
    "                ('relu2',nn.LeakyReLU()),\n",
    "                ('bn2',nn.BatchNorm2d(32)),\n",
    "                ('max1',nn.MaxPool2d(2,2))               # [batch,32,28,28] -> [batch,32,14,14]\n",
    "            ]))\n",
    "        self.layer2 = nn.Sequential(OrderedDict([\n",
    "                ('conv3',nn.Conv2d(32,64,3,padding=1)),  # [batch,32,14,14] -> [batch,64,14,14] \n",
    "                ('relu3',nn.LeakyReLU()),\n",
    "                ('bn3',nn.BatchNorm2d(64)),\n",
    "                ('max2',nn.MaxPool2d(2,2)),              # [batch,64,14,14] -> [batch,64,7,7]\n",
    "                ('conv4',nn.Conv2d(64,128,3,padding=1)), # [batch,64,7,7]   -> [batch,128,7,7]\n",
    "                ('relu4',nn.LeakyReLU())\n",
    "            ]))\n",
    "        self.fc = nn.Sequential(\n",
    "                nn.Linear(128*7*7,1),\n",
    "                nn.Sigmoid()\n",
    "            )\n",
    "        self.fc2 = nn.Sequential(\n",
    "                    nn.Linear(128*7*7,10),\n",
    "                    nn.LeakyReLU()\n",
    "            )\n",
    "\n",
    "    def forward(self,x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        out = out.view(batch_size//num_gpus, -1)\n",
    "        output = self.fc(out)\n",
    "        label = self.fc2(out)\n",
    "        return output,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# put class instance on multi gpu\n",
    "\n",
    "generator = nn.DataParallel(Generator()).cuda()\n",
    "discriminator = nn.DataParallel(Discriminator()).cuda()\n",
    "\n",
    "# put labels on multi gpu\n",
    "\n",
    "ones_label = Variable(torch.ones(batch_size,1)).cuda()\n",
    "zeros_label = Variable(torch.zeros(batch_size,1)).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# loss function and optimizer \n",
    "# this time, use LSGAN loss(https://arxiv.org/abs/1611.04076v2)\n",
    "\n",
    "loss_func = nn.MSELoss()\n",
    "gen_optim = torch.optim.Adam(generator.parameters(), lr=learning_rate)\n",
    "dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--------model restored--------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# model restore\n",
    "\n",
    "try:\n",
    "    generator, discriminator = torch.load('./model/infogan.pkl')\n",
    "    print(\"\\n--------model restored--------\\n\")\n",
    "except:\n",
    "    print(\"\\n--------model not restored--------\\n\")\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Generator. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n",
      "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Discriminator. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0th iteration gen_loss: \n",
      " 1.9374\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      " dis_loss: \n",
      " 0.1373\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "\n",
      "1th iteration gen_loss: \n",
      " 1.9517\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      " dis_loss: \n",
      "1.00000e-02 *\n",
      "  9.8031\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "\n",
      "2th iteration gen_loss: \n",
      " 1.9305\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      " dis_loss: \n",
      "1.00000e-02 *\n",
      "  7.9263\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# train \n",
    "\n",
    "for i in range(epoch):\n",
    "    for j,(image,label) in enumerate(train_loader):\n",
    "        \n",
    "        # put image & label on gpu\n",
    "        \n",
    "        image = Variable(image).cuda()\n",
    "        label = torch.from_numpy(int_to_onehot(label.numpy()))\n",
    "        label = Variable(label.type_as(torch.FloatTensor())).cuda()\n",
    "    \n",
    "        # generator \n",
    "        \n",
    "        for k in range(2):\n",
    "            z_random = np.random.rand(batch_size,100)\n",
    "            z_label = np.random.randint(0, 10, size=batch_size)\n",
    "            z_label_onehot = int_to_onehot(z_label)\n",
    "            \n",
    "            # change first 10 labels from random to 0~9          \n",
    "            for l in range(10):\n",
    "                z_label[l]=l\n",
    "            \n",
    "            # preprocess z\n",
    "            z = np.concatenate([z_random, z_label_onehot], axis=1)\n",
    "            z = torch.from_numpy(z).type_as(torch.FloatTensor())\n",
    "            z = Variable(z).cuda()\n",
    "            z_label_onehot = torch.from_numpy(z_label_onehot).type_as(torch.FloatTensor())\n",
    "            z_label_onehot = Variable(z_label_onehot).cuda()\n",
    "\n",
    "            # calculate loss and apply gradients\n",
    "            # gen_loss = gan loss(fake) + categorical loss\n",
    "            gen_optim.zero_grad()\n",
    "            gen_fake = generator.forward(z)\n",
    "            dis_fake,label_fake = discriminator.forward(gen_fake)\n",
    "            gen_loss = torch.sum(loss_func(dis_fake,ones_label)) \\\n",
    "                      + discrete_latent_size * torch.sum(loss_func(label_fake,z_label_onehot))\n",
    "            gen_loss.backward(retain_variables=True)\n",
    "            gen_optim.step()\n",
    "\n",
    "        # discriminator\n",
    "        # dis_loss = gan_loss(fake & real) + categorical loss\n",
    "        dis_optim.zero_grad()\n",
    "        dis_real, label_real = discriminator.forward(image)\n",
    "        dis_loss = torch.sum(loss_func(dis_fake,zeros_label)) \\\n",
    "                  + torch.sum(loss_func(dis_real,ones_label)) \\\n",
    "                  + discrete_latent_size * torch.sum(loss_func(label_real,label))\n",
    "        dis_loss.backward()\n",
    "        dis_optim.step()\n",
    "    \n",
    "    # model save\n",
    "    if i % 5 == 0:\n",
    "        torch.save([generator,discriminator],'./model/infogan.pkl')\n",
    "\n",
    "    # print loss and image save\n",
    "    print(\"{}th iteration gen_loss: {} dis_loss: {}\".format(i,gen_loss.data,dis_loss.data))\n",
    "    v_utils.save_image(gen_fake.data[0:20],\"./result/gen_{}.png\".format(i), nrow=5)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
